# ------------------------------------------------------------------------------
# Runs GBMs for doctors' simplified risk model
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=50000]" bash 01_fit_gbms.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fix R's RNG state so the subsampled boosters below are reproducible run to
# run (set.seed() coerces the seed to integer, so 930L is the same seed).
set.seed(seed = 930L)

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(optparse)
library(glue)
library(data.table)
library(tidyverse)
library(Matrix)
library(dplyr)
library(xgboost)

# Project utility module, loaded through the `modules` package.
# NOTE(review): `u` is not referenced later in this script — confirm needed.
u <- modules::use(here("lib", "util.R"))
# Scratch directory for this step; the fitted model lists are written here.
temp <- here("code", "06_physician_boundedness", "02_behavioral_gbms", "temp")

# Load Data --------------------------------------------------------------------
message("Loading data...")
# Central registry of project file locations.
paths <- here("lib", "filepaths.yml") %>% read_yaml()
# Directory holding the random-split cohort id tables.
cohort_fp <- paths$modeling$dir %>% file.path("cohorts", "random")
# Training feature matrix and its row-aligned cohort/outcome table.
x <- paths$features$train_tested %>% readRDS()
ids <- cohort_fp %>% file.path("train_cohort_tested.rds") %>% readRDS()

# Subset -----------------------------------------------------------------------
message("Subsetting to tested...")
# Keep rows tested within the 10-day window that trip none of the exclusion
# flags; which() discards rows whose combined condition is NA.
is_kept <- ids$test_010_day &
  !(ids$excl_flag_c_int | ids$excl_flag_chronic | ids$excl_flag_death)
keep_obs <- which(is_kept)
# Subset features and ids with the same row index so they stay aligned.
x <- x[keep_obs, ]
ids <- ids[keep_obs, ]
# Outcome base rate among tested rows.
mean_stent <- ids %>%
  filter(test_010_day) %>%
  pull(stent_or_cabg_010_day) %>%
  mean()

# Parameters -------------------------------------------------------------------
message("Fetching parameters...")
tuning_path <- file.path(
  paths$modeling$dir,
  "tuning", "random", "all",
  "gbm__stent_or_cabg_010_day__tested.rds"
)
tuning <- readRDS(tuning_path)

message("Choosing parameters...")
# Pick the tuning run with the lowest log-loss, breaking ties by the largest
# max_depth, then subsample, then colsample_bytree.
# NOTE(review): top_n() keeps ties, so more than one row could survive —
# confirm a single row remains before it is used as xgboost params.
# NOTE(review): select(-logloss, everything()) appears to keep every column
# (moving logloss last) rather than dropping it; use select(-logloss) if the
# intent was to remove it.
# The final mutate() hard-codes eta, max_depth and colsample_bytree,
# overriding the tuned values. NOTE(review): confirm this is intentional.
best_params <- tuning %>%
  unnest() %>%
  top_n(1, -logloss) %>%
  top_n(1, max_depth) %>%
  top_n(1, subsample) %>%
  top_n(1, colsample_bytree) %>%
  select(-logloss, everything()) %>%
  mutate(eta = 0.3, max_depth = 5, colsample_bytree = 0.5)

print(best_params)

# Fit Model --------------------------------------------------------------------
# For each of `num_models` replicates, fit:
#   * one single-round ("dsmall") booster per depth in `dsmall_depths`,
#     stored as "model_{i}_d{depth}" in gbm_dsmall_list, and
#   * one 500-round booster with `best_params`, stored as "model_{i}" in
#     gbm_list.
num_models <- 100
dsmall_depths <- 2:4
# Loop-invariant thread count, computed once.
# NOTE(review): sized by the number of distinct training folds — confirm.
n_threads <- n_distinct(ids$train_fold)
message("Fitting ", num_models, " models...")
gbm_list <- list()
gbm_dsmall_list <- list()
for (i in seq_len(num_models)) {
  for (j in dsmall_depths) {
    params_dsmall <- copy(best_params)
    # Depth follows the depth index `j` (as the "_d{j}" name says), not the
    # replicate index `i`.
    params_dsmall$max_depth <- j
    gbm_dsmall <- xgboost(
      params = params_dsmall,
      data = x,
      label = ids$stent_or_cabg_010_day,
      nthread = n_threads,
      nrounds = 1L,
      verbose = 0
    )
    gbm_dsmall_list[[glue("model_{i}_d{j}")]] <- gbm_dsmall
  }
  gbm <- xgboost(
    params = best_params,
    data = x,
    label = ids$stent_or_cabg_010_day,
    nthread = n_threads,
    nrounds = 500L,
    verbose = 0
  )
  gbm_list[[glue("model_{i}")]] <- gbm
}

# Save -------------------------------------------------------------------------
message("Saving...")
# Create the output directory if needed so the fitted models are not lost to
# a missing-directory error at write time.
dir.create(temp, showWarnings = FALSE, recursive = TRUE)
# NOTE(review): xgboost's docs recommend xgb.save()/xgb.save.raw() over
# saveRDS for boosters that must load under other xgboost versions.
write_rds(
  gbm_list,
  file.path(temp, "gbms__stent_or_cabg_010_day__tested.rds")
)
write_rds(
  gbm_dsmall_list,
  file.path(temp, "gbms__dsmall__stent_or_cabg_010_day__tested.rds")
)

message("Done.")
